数据集处理
基础操作
TensorDataset
与 DataLoader
import torch
import torch.utils.data
# 创建一个tensor输入
inputs = torch.tensor([[1.0, 2.0], [2.0, 3.0], [3.0, 4.0], [4.0, 5.0], [5.0, 6.0],
[6.0, 7.0], [7.0, 8.0], [8.0, 9.0], [9.0, 10.0], [10.0, 11.0]])
# 创建一个tensor标签
labels = torch.tensor([0, 1, 0, 1, 0, 1, 0, 1, 0, 1])
# 创建TensorDataset对象
dataset = torch.utils.data.TensorDataset(inputs, labels) # 利用torch.utils.data.TensorDataset()
print(dataset)
print(type(dataset))
for data, i in dataset:
print(data, i)
print()
# 创建Dataloader对象
dataloader = torch.utils.data.DataLoader(dataset, batch_size=2, shuffle=True) # 利用torch.utils.
print(dataloader)
print(type(dataloader))
for batch_data, batch_label in dataloader:
print(batch_data, batch_label)
输出
<torch.utils.data.dataset.TensorDataset object at 0x7ff28b5da220>
<class 'torch.utils.data.dataset.TensorDataset'>
tensor([1., 2.]) tensor(0)
tensor([2., 3.]) tensor(1)
tensor([3., 4.]) tensor(0)
tensor([4., 5.]) tensor(1)
tensor([5., 6.]) tensor(0)
tensor([6., 7.]) tensor(1)
tensor([7., 8.]) tensor(0)
tensor([8., 9.]) tensor(1)
tensor([ 9., 10.]) tensor(0)
tensor([10., 11.]) tensor(1)
<torch.utils.data.dataloader.DataLoader object at 0x7ff28b5dabb0>
<class 'torch.utils.data.dataloader.DataLoader'>
tensor([[4., 5.],
[6., 7.]]) tensor([1, 1])
tensor([[5., 6.],
[2., 3.]]) tensor([0, 1])
tensor([[ 7., 8.],
[10., 11.]]) tensor([0, 1])
tensor([[1., 2.],
[8., 9.]]) tensor([0, 1])
tensor([[ 9., 10.],
[ 3., 4.]]) tensor([0, 0])
经典数据集导入
MNIST
内容 | 链接 | |
---|---|---|
CNN
https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html
内容 | 链接 | |
---|---|---|
神经网络搭建
LeNet-5
内容 | 链接 | 创作者 |
---|---|---|
Python Class Tutorial | 链接🔗 | Corey Schafer |